#! /usr/bin/env python
## This file is part of biopy.
## Copyright (C) 2013 Joseph Heled
## Author: Joseph Heled <jheled@gmail.com>
## See the files gpl.txt and lgpl.txt for copying conditions.

from __future__ import division

import argparse, sys, os.path

# Command line definition. Several stages may be requested in one invocation;
# the code below runs them in the order derep, declutter, forest, stitch,
# shave, cluster.
parser = argparse.ArgumentParser(description= """ %(prog)s [OPTIONS] in-file""")

# Counted: one -p enables progress messages, two or more also enable verbose
# output.
parser.add_argument("-p", "--progress", dest="progress",
                    help="send progress messages to terminal (standard error)",
                    action='count', default = 0)

# Default is None (not False): the forest sanity check below tests
# 'options.derep is not None'.
parser.add_argument("--derep", default = None, action = "store_true",
                    help="eliminate exact duplicates, add size information to name."
                    " output sequences are sorted by decreasing abundance (size).")

parser.add_argument("--declutter", metavar="TH(,TH1,TH2,...)",
                    default = None,
                    help="partition sequences into a hierarchical collection of islands.")

parser.add_argument("--forest", action="store_true", default = False,
                    help="build UPGMA trees for all islands.")

parser.add_argument("--stitch", metavar="TH", default = None, type = float,
                    help="Merge islands into larger islands, up to level TH.")

parser.add_argument("--shave", metavar="TH", default = None, type = float,
                    help="Cluster into OTUS at the TH level and generate a"
                    " consensus sequence for each OTU.")

parser.add_argument("--cluster", metavar="TH", default = None, type = float,
                    help="Cluster into OTUS at the TH level and output a file"
                    " with the clusters.")

parser.add_argument("--max-single-tree", dest="maxClade", metavar="N", default = 6000, type = int,
                    help="upper limit on the size of single UPGMA trees to build."
                    " This is only a guideline and not an absolute guarantee."
                    " It is advisable that system resources are sufficient to"
                    " build trees twice or three times as large"
                    " (default %(default)s).")

parser.add_argument("--terminate-search-count", dest="failsTH", type = int,
                    metavar="N", default = 20,
                    help="Searches for a matching sequence during declutter are"
                    " terminated after N failures. Increasing the limit increases"
                    " the search sensitivity with a running time cost. You"
                    " probably want to keep this under 40 (default %(default)s).")

parser.add_argument("--stitch-parameters", dest="stitchArgs", metavar="VALUES", default = None,
                    help="Change those only if you really know what you are doing."
                    " Comma separated format:"
                    " MaxReps,MaxConsReps,NoConsTH,RefinmentFactor,RefinmentLimit")

# parser.add_argument("--verbose", action="store_true", default = False)

#parser.add_argument("--compress", metavar='N', type = int, default = -1,
#                    help="""compression level for distance files (%(default)d for no compression)""")

parser.add_argument("--save-distances", dest="savedists", metavar="WHAT",
                    choices = ["no", "plain", "compressed-fast", "compressed-normal",
                               "compressed-best"], default = "compressed-fast",
                    help="no, plain (no compression), compressed-fast/normal/best"
                    " (default %(default)s).")

parser.add_argument("--overwrite", action="store_true", default = False,
                    help="blindly trash existing files.")

parser.add_argument("--match-scores", dest="matchScores",  metavar="M,X,G,E,F",
                    default="10,-5,-6,-6,1",
                    help="Alignment scoring parameters, a comma separated list of"
                    " match,mismatch,gap-open,gap-extend,free-ends"
                    " (default %(default)s).")

# store_false: giving the flag turns the (default) correction off.
parser.add_argument("--use-sequence-identity", dest = "JCcorrection",
                    action="store_false", default = True,
                    help="use uncorrected sequence identity. by default use Jukes-Cantor correction.")

parser.add_argument("--location-from-name", dest="locfromname", metavar = "RE", default = None,
                    help="A regular expression extracting a location from the"
                    " sequence name. For example: '\\|([^;]*)' (remember the"
                    " ';size' added by dereping).")

parser.add_argument("--sequences-extension", dest="seqsext", metavar="STRING", default = "fasta",
                    help="File name extension of sequences files"
                    " (default %(default)s).")

parser.add_argument("--derep-extension", dest="derepext", metavar="STRING", default = "derep",
                    help="String used to construct the de-replication output"
                    " file name (default %(default)s).")

parser.add_argument("--declutter-extension", dest="declext", metavar="STRING", default = "decl",
                    help="String used to construct the declutter output file"
                    " name (default %(default)s).")

parser.add_argument("--treelogs-extension", dest="treelogext", metavar="STRING", default = "trees",
                    help="String used to construct the output file name from"
                    " decluttering, which is a NEXUS file containing trees"
                    " (default %(default)s).")

parser.add_argument("--clusters-extension", dest="clustersext", metavar="STRING", default = "clusters",
                    help="String used to construct the output file name from"
                    " the cluster option (default %(default)s).")

parser.add_argument('infile', metavar='FILE', help="Sequences (FASTA) or NEXUS")

# Parse the process command line. This happens before the biopy imports
# below, so usage errors and --help are handled before those modules load.
options = parser.parse_args(sys.argv[1:])

from itertools import count
from collections import defaultdict

from biopy.genericutils import tohms, fileFromName
from biopy import otus, align, __version__
from biopy.bioutils import readFasta
from biopy.treeutils import TreeLogger, CAhelper
from biopy.INexus import INexus

# -p is counted: once for progress messages, twice or more for verbose output
# as well. 'fverbose' is the stream handed to library calls for their verbose
# output, or None to keep them quiet.
progress = options.progress > 0
verbose = options.progress > 1
fverbose = None if not verbose else sys.stderr
nexusLogExt = "trees.nex"

# main: basename + .fasta
# derep: basename + .derep.fasta
# declutter at X:  "basename + .derep.decl_pX.trees"  or "basename + .decl_pX.trees"
# forest at X:     "basename + .pX.trees"

# can have derep+declutter+forest
#          declutter+forest
#          forest (on decluttered file, output meaning unclear otherwise)

# Refuse --forest combined with --derep when no --declutter was requested.
if options.forest and options.derep is not None and options.declutter is None:
  print >> sys.stderr, "Confusing order, My Lord."
  sys.exit(1)

def fatalError(msg):
  """Print msg on standard error and terminate with exit status 1."""
  print >> sys.stderr, msg
  raise SystemExit(1)
  
def progressMessage(msg, stayOnLine = False):
  """Write a progress message to standard error.

  With stayOnLine the line is left open (no newline) and the stream is
  flushed; otherwise the message is written as a complete line.
  """
  if not stayOnLine:
    print >> sys.stderr, msg
    return
  print >> sys.stderr, msg,
  sys.stderr.flush()

def noOverwrite(fname):
  """Exit with status 1 if fname exists and --overwrite was not given."""
  # The existence test is skipped altogether when overwriting is allowed.
  if not options.overwrite and os.path.exists(fname):
    print >> sys.stderr, "Cowardly refusing to overwrite", fname
    sys.exit(1)


# Sequences shared between stages, as (header, sequence) pairs. Empty until a
# stage reads or produces them.
allSeqs = []
# File the current stage works from. The derep, declutter and forest stages
# each repoint it at their own output for the benefit of the next stage.
infile = options.infile

# Stage: de-replication. Writes one record per distinct sequence, with a
# ';size=N;' suffix added to the header, most abundant first.
if options.derep:
  try:
    # Accept the input name with or without the sequences extension.
    if not os.path.exists(infile):
      infile = infile + os.path.extsep + options.seqsext
      if not os.path.exists(infile):
        fatalError("cant find input file (%s)" % infile)

    # Output is <base>.<derepext>.<seqsext>, where <base> is the input name
    # less the sequences extension (when it has one).
    base,ext = os.path.splitext(infile)
    if ext != os.path.extsep + options.seqsext:
      base = infile
    derepName = os.path.extsep.join([base, options.derepext, options.seqsext])

    noOverwrite(derepName)

    allSeqs = list(readFasta(fileFromName(infile), stripGaps = True))
    # for next stage, if any
    infile = derepName

    if progress:
      progressMessage("finding duplicates ..., ",True)

    # Each element of 'dups' is used below as a group of indices into allSeqs.
    dups = otus.findDuplicates([x[1] for x in allSeqs])

    # Sequences after dereping
    remaining = []
    seen = set()

    # The 'with' block closes the output even when writing fails part way.
    with open(derepName, 'w') as derepFile:
      if progress:
        progressMessage("writing output ..., ",True)

      # Largest groups first; the first member of a group represents it.
      for d in sorted(dups, key = len, reverse = True):
        rep = allSeqs[d[0]]
        h,s = rep[0] + (';size=%d;' % len(d)),rep[1]
        remaining.append((h,s))
        # "#D" line lists the original headers of all members of the group.
        print >> derepFile, "#D ",'\t'.join([allSeqs[x][0] for x in d])
        print >> derepFile, h
        print >> derepFile, s
        seen.update(d)

      # Sequences not part of any group are written as size 1.
      for i,rep in enumerate(allSeqs):
        if i not in seen:
          h,s = rep[0] + ';size=1;',rep[1]
          remaining.append((h,s))
          print >> derepFile, "#D ",h
          print >> derepFile, h
          print >> derepFile, s

    if progress:
      progressMessage("done.")

    allSeqs = remaining
  except Exception as e:
    # str(e), not e.message: the latter is deprecated and is empty for
    # exceptions built from several arguments (IOError, for one).
    print >> sys.stderr, "Error:", str(e)
    sys.exit(1)

# Alignment scoring parameters, parsed once from --match-scores and passed to
# the declutter and forest stages.
try:
  matchScores = align.parseScores(options.matchScores)
except RuntimeError as e:
  # str(e) rather than the deprecated e.message, as in the other handlers.
  print >> sys.stderr, "**Error in alignment scores:", str(e)
  sys.exit(1)

# Stage: declutter. Partitions the sequences at the requested thresholds and
# logs the resulting trees to <namebase>.<declext>_<levels>.<treelogext>.
if options.declutter is not None:
  try:
    # Thresholds, highest first.
    ths = sorted([float(x) for x in options.declutter.split(',')], reverse = True)

    # Note: values below 20 are raised to 20.
    failsTH = max(options.failsTH, 20)

    if not allSeqs:
      # No derep stage ran in this invocation: locate the sequences file,
      # trying <infile>, then <infile>.<derepext>.<seqsext>, then
      # <infile>.<seqsext>.
      if not os.path.exists(infile):
        d = os.path.extsep.join([infile, options.derepext, options.seqsext])
        if os.path.exists(d):
          infile = d
        else:
          d = os.path.extsep.join([infile, options.seqsext])
          if not os.path.exists(d):
            fatalError("cant find input file for declutter (%s)" % infile)
          infile = d
      if progress:
        progressMessage("reading sequences from %s" % infile)
      # Read everything inside the 'with' so the handle is closed promptly.
      with open(infile) as seqsIn:
        allSeqs = list(readFasta(seqsIn, stripGaps=True))

    # Output name is built from the input name less the sequences extension.
    nameparts = infile.split(os.path.extsep)
    if nameparts[-1] == options.seqsext:
      nameparts.pop(-1)
    namebase = os.path.extsep.join(nameparts)

    s = '_'.join([otus.thstr(th) for th in ths])
    declName = os.path.extsep.join([namebase, "%s_%s" % (options.declext,s),
                                    options.treelogext])
    noOverwrite(declName)
    # for next stage
    infile = declName

    if progress:
      progressMessage("declutter at %s" % options.declutter + (" ...," if not verbose else ""), not verbose)

    brk = otus.deClutterDown([x[1] for x in allSeqs], ths, options.maxClade,
                              correction = options.JCcorrection,
                              failsTH = failsTH, matchScores = matchScores,
                              verbose = fverbose)
    trs = otus.declutterToTrees(brk, ths)

    # Labels pair a running index with each header less its first character.
    tlog = TreeLogger(declName, labels = zip(count(), [x[0][1:] for x in allSeqs]),
                      argv = sys.argv, version = __version__,
                      overwrite = options.overwrite)
    for t,nm in trs:
      tlog.outTree(t, name = nm)

    tlog.close()

    if progress:
      # Report a rough estimate of the time needed to build the forest.
      from biopy.otus import estimateNA, alCellsPerSecond
      from biopy.parseNewick import parseNewick
      tr = []
      for t,nm in trs:
        t = parseNewick(t)
        t.name = nm
        tr.append(t)

      na = estimateNA(tr, 2)
      avgSeqLen = sum([len(x[1]) for x in allSeqs])/len(allSeqs)
      tx = [tohms((x * avgSeqLen**2)/alCellsPerSecond) for x in na]
      progressMessage("Rough build time %s - %s (%d-%d)," %
                      (tx[0], tx[1], na[0], na[1]), not verbose)
      progressMessage("Declutter done." if verbose else "done.")

  except Exception as e:
    # str(e), not e.message: the latter is deprecated and is empty for
    # exceptions built from several arguments (IOError, for one).
    print >> sys.stderr, "Error:", str(e)
    sys.exit(1)

# Compression level handed to the distance saving code of the forest stage;
# None means distances are not saved at all.
## seems like the gzip module does not respect compression level
compressionLevel = None
if options.savedists != "no":
  compressionLevel = {"plain" : 0,
                      "compressed-fast" : 1,
                      "compressed-normal" : 6,
                      "compressed-best" : 9}[options.savedists]
  
# Stage: forest. Builds one tree per tree read from the declutter log and
# writes them to <base>.<level>.<treelogext>.
if options.forest:
  try:
    # Lowest threshold encoded in the input name; stays 0 when there is none.
    thdecl = 0

    if not os.path.exists(infile):
      fatalError("cant find input file for forest (%s)" % infile)

    nameparts = infile.split(os.path.extsep)
    if nameparts[-1] == options.treelogext:
      nameparts.pop(-1)
    ths = []
    if nameparts[-1].startswith(options.declext + '_'):
      # Name part looks like <declext>_pNN_pMM...; recover the thresholds.
      thparts = nameparts[-1].split('_')[1:]

      # Raised explicitly (not asserted) so the check survives python -O.
      if thparts[0] == "auto":
        raise RuntimeError("no auto support yet")

      for x in thparts:
        # x[:1] (not x[0]) so an empty part is reported, not an IndexError.
        if x[:1] == 'p' and x[1:].isdigit():
          ths.insert(0, float('.' + x[1:]))
        else:
          fatalError("Invalid declutter level '%s' in %s" % (x, infile))

      thdecl = min(ths)
      nameparts.pop(-1)

    if not allSeqs:
      # No earlier stage supplied the sequences: try <base>.<seqsext>, then
      # <base> itself.
      base = os.path.extsep.join(nameparts)
      seqsfile = base + os.path.extsep + options.seqsext
      if not os.path.exists(seqsfile):
        seqsfile = base
        if not os.path.exists(seqsfile):
          fatalError("Unable to located sequences for " + infile)

      if progress:
        progressMessage("reading sequences from %s" % seqsfile)
      with open(seqsfile) as seqsIn:
        allSeqs = list(readFasta(seqsIn))

    forestLevel = otus.thstr(thdecl) if thdecl else "auto"
    forestBase = os.path.extsep.join(nameparts + [forestLevel])
    forestName = forestBase + os.path.extsep + options.treelogext

    # Report the level: options.forest itself is just a boolean flag.
    if progress:
      progressMessage("Building trees for %s forest" % forestLevel)

    tlog = TreeLogger(forestName, labels = zip(count(), [x[0][1:] for x in allSeqs]),
                      argv = sys.argv, version = __version__,
                      overwrite = options.overwrite)

    # Sequence name (header less its first character) -> index in allSeqs.
    nameToSeqIndex = dict([(h[1:],k) for k,(h,s) in enumerate(allSeqs)])

    def getSize(h):
      """Return N from a trailing ';size=N;' in header h, or 1 if absent."""
      if h.endswith(';'):
        i = h[:-1].rfind(';')
        if i >= 0 and h[i+1:i+6] == "size=":
          return int(h[i+6:-1])
      return 1

    indexToSize = dict([(k,getSize(h)) for k,(h,s) in enumerate(allSeqs)])

    def getSeqs(tax):
      """Return the sequences and the sizes of the named taxa, as two lists."""
      try:
        i = [nameToSeqIndex[x] for x in tax]
      except KeyError:
        fatalError("mismatch between trees and sequences numbering")
      return [allSeqs[x][1] for x in i], [indexToSize[x] for x in i]

    for cnt,tree in enumerate(INexus(simple=1).read(infile)):
      tax = [x.strip("'") for x in tree.get_taxa()]
      missing = [x for x in tax if x not in nameToSeqIndex]
      if missing:
        raise RuntimeError("Tree taxon not found in fasta file:" +
                           ' '.join(['!'+x+'!' for x in missing]))
      # Output trees name their tips by sequence index.
      itx = [str(nameToSeqIndex[x]) for x in tax]
      if len(tax) > 1:
        seqs,wts = getSeqs(tax)
        dsFile = forestBase + '_c%d' % cnt + '.dists'
        if os.path.exists(dsFile) or os.path.exists(dsFile + '.gz'):
          # Reuse distances saved by a previous run.
          ds,lb = otus.getDistanceMatrix(dsFile)
          if tax != lb:
            raise RuntimeError("Invalid distances in %s." % dsFile)
          tr = otus.treeFromDists(ds, tax = itx, weights = wts, asString = True)
        else:
          tr,ds = otus.treeFromSeqs(seqs, itx, weights = wts,
                                    matchScores = matchScores,
                                    correction = options.JCcorrection,
                                    saveDists = None if compressionLevel is None
                                    else (dsFile,tax,compressionLevel),
                                    asString = True, verbose = fverbose)
      else:
        # A single taxon: nothing to build.
        tr = '(' + itx[0] + ')'
      tlog.outTree(tr, name = tree.name)
      # Flush per tree, so finished trees reach the file as they are built.
      tlog.outFile.flush()

    tlog.close()

    # for next stage (if any)
    infile = forestName

    if progress:
      progressMessage("Done building trees for %s forest." % forestLevel)
  except Exception as e:
    # str(e) is always a string, so the slice is safe; e.message may be empty
    # or not sliceable at all.
    print >> sys.stderr, "Error:", str(e)[:1000]
    sys.exit(1)


# Stage: stitch. Merges trees read from the input log into larger trees, up
# to the stitch level, and writes the result to a new tree log.
if options.stitch is not None:
  try:
    thstitch = float(options.stitch)

    if not os.path.exists(infile):
      fatalError("cant find input file for stitching (%s)" % infile)

    # Tuning parameters from --stitch-parameters; an empty field (or no
    # option at all) selects the default for that position.
    a = [None]*5
    if options.stitchArgs is not None:
      a = options.stitchArgs.split(',')
      if len(a) != 5:
        fatalError("expecting 5 comma separated values for --stitch-parameters")

    maxReps = int(a[0] or 20)
    maxConsReps = int(a[1] or 100)
    noConsTH = float(a[2] or 0.02)
    refinmentFactor = 1 + float(a[3] or 0.1)
    refinmentLimit = float(a[4] or .15)
    if not ( 0 < maxReps and 0 < maxConsReps and 0 < noConsTH < 0.05
             and 1 <= refinmentFactor and 0 < refinmentLimit ):
      fatalError("stitch-parameters dont make sense")

    # Strip the tree log extension and a trailing level ('pNN' or 'auto')
    # from the input name to get the parts the output name is built from.
    nameparts = infile.split(os.path.extsep)
    if nameparts[-1] == options.treelogext:
      nameparts.pop(-1)

    plast = nameparts[-1]
    # plast[:1] (not plast[0]) so an empty part cannot raise IndexError.
    if (plast[:1] == 'p' and plast[1:].isdigit()) or plast == 'auto':
      nameparts.pop(-1)

    base = os.path.extsep.join(nameparts + ([otus.thstr(thstitch)] if thstitch < 1 else []))
    finalLogName = base + os.path.extsep + options.treelogext
    noOverwrite(finalLogName)

    if not allSeqs:
      # No earlier stage supplied the sequences: try <base>.<seqsext>, then
      # <base> itself.
      base = os.path.extsep.join(nameparts)
      seqsfile = base + os.path.extsep + options.seqsext
      if not os.path.exists(seqsfile):
        seqsfile = base
        if not os.path.exists(seqsfile):
          fatalError("Unable to located sequences for " + infile)

      if progress:
        progressMessage("reading sequences from %s" % seqsfile)
      with open(seqsfile) as seqsIn:
        allSeqs = list(readFasta(seqsIn))

    if progress:
      progressMessage("stitching trees for %s forest" % options.stitch)

    trees = []
    # Levels found in the tree names; collected but not used further below.
    levels = set()
    for tree in INexus(simple=1).read(infile):
      if not tree.name.startswith('cluster_'):
        raise RuntimeError("unexpected tree name '%s' in %s" % (tree.name, infile))
      levels.add(otus.thfromstr(tree.name.split('_')[-1]))

      trees.append(tree)

    levels = list(sorted(levels))

    # Sequence name (header less its first character) -> index in allSeqs.
    nameToSeqIndex = dict([(h[1:],k) for k,(h,s) in enumerate(allSeqs)])

    getseq = lambda x : allSeqs[nameToSeqIndex[x.strip("'")]][1]

    intrim = False;  # For debugging
    if intrim:
      tlog = TreeLogger(os.path.extsep.join(nameparts) + ".intrim.trees", overwrite=1) ; # intrim

    # Put trees on a stack and assemble a group of e-islands into one l-island if
    # island level (l) is not greater than stitch level. The group of e-islands
    # is in one block on the top, since they were generated in this order by the
    # decluttering. Scan trees from last to first, so the zero index tree marks
    # the end of the group.

    stack = []
    for tree in reversed(trees):
      treeLevel = tree.name.split('_')[1:]
      th = otus.thfromstr(treeLevel[-1])

      # Assemble group when a group is on the stack. This loop may repeat
      # several times when the tree pops several levels.
      #
      # Needs stitching, size > 1 and all group on stack
      while th < thstitch and len(treeLevel) > 2 and int(treeLevel[-2]) == 0:
        # move all e-trees from stack to 'c'
        c = [tree]
        while len(stack) and stack[-1][0][:-2] == treeLevel[:-2]:
          l,t = stack.pop(-1)
          c.append(t)

        th = otus.thfromstr(treeLevel[-1])
        thTo = otus.thfromstr(treeLevel[-3])

        atree = otus.assembleTree(c, th, thTo, getseq, verbose = fverbose,
                                  nMaxReps = maxReps, maxPerCons = maxConsReps,
                                  lowDiversity = noConsTH, refineFactor = refinmentFactor,
                                  refineUpperLimit = refinmentLimit)

        # Set for loop condition: tree level is reduced by one level, tree name
        # and 'th' set to reflect new level

        treeLevel = treeLevel[:-2]
        atree.name = "cluster_" + '_'.join(treeLevel)

        th = thTo
        tree = atree

        if intrim:
          tlog.outTree(atree, name = atree.name) ; # intrim
          tlog.outFile.flush() ; # intrim

      # Push back assembled tree
      stack.append([treeLevel,tree])

    # Split what is left: trees already at or above the stitch level are
    # written as they are, the rest go through one final assembly.
    asis = []
    assemble = []
    ath = None
    for c,tree in stack:
      treeLevel = tree.name.split('_')[1:]
      th = otus.thfromstr(treeLevel[-1])
      if th >= thstitch:
        asis.append(tree)

        # Rename tips from sequence name to sequence index.
        for x in tree.get_terminals():
          d = tree.node(x).data
          d.taxon = str(nameToSeqIndex[d.taxon.strip("'")])

      else:
        # All trees awaiting assembly are expected to share one level.
        assert ath is None or ath == th
        ath = th
        assemble.append(tree)

    if len(assemble):
      # NOTE(review): unlike the call in the loop above, this one does not
      # pass the --stitch-parameters values - confirm this is intended.
      atree = otus.assembleTree(assemble, ath, thstitch, getseq, verbose = fverbose)

      if len(asis) > 0:
        p = assemble[0].name.split('_')
        n = '_'.join(p[:-2])
        atree.name = n
      else:
        atree.name = "tree"

      if intrim:
        tlog.outTree(str(atree), name = atree.name) ; # intrim
        tlog.outFile.flush() ; # intrim

      for x in atree.get_terminals():
        d = atree.node(x).data
        d.taxon = str(nameToSeqIndex[d.taxon.strip("'")])

      # appending at end guaranteed to be valid ordering for a possible future stitch?
      asis.append(atree)

    if intrim:
      tlog.close() ; # intrim

    tlog = TreeLogger(finalLogName, labels = zip(count(), [x[0][1:] for x in allSeqs]),
                      argv = sys.argv, version = __version__,
                      overwrite = options.overwrite)
    # We iterate in reverse, reverse again to regain order
    for t in reversed(asis):
      tlog.outTree(t, name = t.name)

    tlog.close()

  except Exception as e:
    # str(e) is always a string, so the slice is safe; e.message may be empty
    # or not sliceable at all.
    print >> sys.stderr, "Error:", str(e)[:1000]
    sys.exit(1)
    
  
# Stage: shave. For every tree in the input log, cuts the tree into clades at
# the shave level and writes one sequence per clade to
# <tree name>.shave_<level>.<seqsext>.
if options.shave is not None:
  try:
    thshave = float(options.shave)

    if not os.path.exists(infile):
      fatalError("cant find input file (%s)" % infile)

    nameparts = infile.split(os.path.extsep)
    if nameparts[-1] == options.treelogext:
      nameparts.pop(-1)

    if not allSeqs:
      # No earlier stage supplied the sequences: try <base>.<seqsext>, then
      # <base> itself.
      base = os.path.extsep.join(nameparts)
      seqsfile = base + os.path.extsep + options.seqsext
      if not os.path.exists(seqsfile):
        seqsfile = base
        if not os.path.exists(seqsfile):
          fatalError("Unable to located sequences for " + infile)

      if progress:
        progressMessage("reading sequences from %s ...," % seqsfile, True)
      with open(seqsfile) as seqsIn:
        allSeqs = list(readFasta(seqsIn))

    # Sequence name (header less its first character) -> index in allSeqs.
    nameToSeqIndex = dict([(h[1:],k) for k,(h,s) in enumerate(allSeqs)])
    getseq = lambda x : allSeqs[nameToSeqIndex[x.strip("'")]][1]

    if progress:
      progressMessage("shaving tree at %s ...," % options.shave, not fverbose)
    if fverbose:
      import time

    # Compile the location pattern once, up front.
    locRE = None
    if options.locfromname:
      import re
      locRE = re.compile(options.locfromname)

    for tree in INexus(simple=1).read(infile):
      seqsOutName = tree.name + ".shave_" + otus.thstr(thshave) + \
                    os.path.extsep + options.seqsext
      noOverwrite(seqsOutName)

      # The 'with' block closes the output even when a clade fails.
      with open(seqsOutName, 'w') as sqout:
        ca = CAhelper(tree)
        clades = otus.clusterFromTree(tree, thshave/2, ca)
        for node in clades:
          sqs = [getseq(x.data.taxon) for x in node.data.terms]

          # Large clades get a timing line in the verbose output.
          timed = fverbose is not None and len(sqs) > 100
          if timed:
            print >> fverbose, len(sqs),"sequences (%.5g)" % node.data.rh,time.strftime("%T"),
            tnow = time.clock()
            fverbose.flush()

          # The clade sequence: a consensus, or the sequence itself when the
          # clade has a single member.
          if len(sqs) > 1:
            cs = otus.doTheCons(sqs, node.data.rh)
          else:
            cs = sqs[0]

          if timed:
            print >> fverbose, tohms(time.clock() - tnow)

          # Total size of the clade and, when requested, size per location.
          # 'comun' always exists (possibly empty): it is read below whether
          # or not --location-from-name was given.
          comun = defaultdict(int)
          sz = 0
          for n in node.data.terms:
            tax = n.data.taxon.strip("'")
            i = tax.find(";size=")
            if i >= 0:
              end = tax[i+6:].index(';')
              isz = int(tax[i+6:i+6+end])
              if isz <= 0:
                raise RuntimeError("invalid size %d in %s" % (isz, tax))
            else:
              isz = 1
            sz += isz
            if locRE is not None:
              r = locRE.search(tax)
              if r:
                comun[r.group(1)] += isz

          # Name: the taxon itself for a single member clade, otherwise the
          # first terminal of each child joined by '$'.
          if len(node.data.terms) == 1:
            nm = node.data.taxon.strip("'")
          else:
            ch = [tree.node(x) for x in node.succ]
            nm = '$'.join([x.data.terms[0].data.taxon.strip("'") for x in ch])

          if comun:
            nm = nm + '$' + ':'.join(["%s=%d" % (n,s) for n,s in comun.items()])

          nm = nm + ('$;size=%d;' % (sz if sz > 0 else len(node.data.terms)))

          print >> sqout, '>' + nm
          print >> sqout, cs

    # Terminates the progress line left open above; nothing to terminate
    # when progress messages are off.
    if progress:
      progressMessage("")

  except Exception as e:
    # str(e) is always a string, so the slice is safe; e.message may be empty
    # or not sliceable at all.
    print >> sys.stderr, "Error:", str(e)[:1000]
    sys.exit(1)

# Stage: cluster. Cuts every tree in the input log into clades at the cluster
# level and writes one line per clade, listing the names of its sequences, to
# <base>.<level>.<clustersext>.
if options.cluster is not None:
  try:
    if not os.path.exists(infile):
      fatalError("cant find input file (%s)" % infile)

    thclust = float(options.cluster)

    # Recover the base name: drop the tree log extension, the derep
    # extension and any trailing 'pNN' level parts. The 'nameparts and'
    # guards keep an emptied list from raising IndexError.
    nameparts = infile.split(os.path.extsep)
    if nameparts and nameparts[-1] == options.treelogext:
      nameparts.pop(-1)
    if nameparts and nameparts[-1] == options.derepext:
      nameparts.pop(-1)

    while nameparts and nameparts[-1][:1] == 'p' and nameparts[-1][1:].isdigit():
      nameparts.pop(-1)
    base = os.path.extsep.join(nameparts)

    seqsfile = base + os.path.extsep + options.seqsext

    if not os.path.exists(seqsfile):
      fatalError("Unable to located original sequences for " + infile)
    if progress:
      progressMessage("reading sequences from %s ...," % seqsfile, True)
    with open(seqsfile) as seqsIn:
      allSeqs = [(h[1:],b) for h,b in readFasta(seqsIn)]

    # Map each sequence name to the list of all names sharing its sequence.
    byseq = defaultdict(list)
    for n,b in allSeqs:
      byseq[b].append(n)
    byname = dict()
    for names in byseq.values():
      for n in names:
        byname[n] = names
    del allSeqs, byseq

    def getMapped(n):
      """Return the names sharing a sequence with tree taxon n.

      The lookup is tried on the name less a trailing ';size=N;' first, then
      on the name as given (quotes stripped in both cases).
      """
      n = norig = n.strip("'")
      if ';' in n:
        ns = n.split(';')
        if not ns[-1]:
          ns = ns[:-1]
        if len(ns) > 1 and ns[-1].startswith('size='):
          ns = ns[:-1]
          n = ';'.join(ns)

      return byname.get(n) or byname[norig]

    if progress:
      progressMessage("clustering ...,", True)

    ucOutName = base + os.path.extsep + otus.thstr(thclust) + os.path.extsep + options.clustersext
    noOverwrite(ucOutName)

    # The 'with' block closes the output even when a tree fails.
    with open(ucOutName, 'w') as ucout:
      for tree in INexus(simple=1).read(infile):
        ca = CAhelper(tree)
        clades = otus.clusterFromTree(tree, thclust/2, ca)
        for node in clades:
          r = []
          for x in node.data.terms:
            r.extend(getMapped(x.data.taxon))

          print >> ucout, '\t'.join(r)

    if progress:
      progressMessage("done.")

  except Exception as e:
    # str(e) is always a string, so the slice is safe; e.message may be empty
    # or not sliceable at all.
    print >> sys.stderr, "Error:", str(e)[:1000]
    sys.exit(1)

#  LocalWords:  UPGMA declutter
